package com.example.nextgen.parser.sql;

import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.example.nextgen.util.database.ColumnEntity;
import com.example.nextgen.util.database.TableEntity;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Analyses a single MySQL statement for the code generator: the tables it references, the
 * conditions bound to {@code ?} placeholders, the selected columns (for queries only), and a
 * MyBatis dynamic-SQL rewrite of the statement text.
 *
 * @author martin
 * @date Created in 2022/12/12 18:04
 */
@Component
public class SQLParser {

    /**
     * One segment per placeholder: the text from the end of the previous segment (or the start of
     * the line, since {@code .} does not match line breaks) up to and including the next {@code ?}.
     */
    private static final Pattern PLACEHOLDER_SEGMENT = Pattern.compile(".*?\\?");

    /**
     * The column compared against a segment's trailing {@code ?}. The column may be back-quoted;
     * a table qualifier such as {@code t.} is simply not captured. Only plain comparison operators
     * and (NOT) LIKE are recognised.
     */
    private static final Pattern PLACEHOLDER_COLUMN = Pattern.compile(
            "`?(\\w+)`?\\s*(?:=|!=|<>|<=>|>=|<=|>|<|(?i:not\\s+like|like))\\s*\\?$");

    /** MyBatis {@code <if>} wrapper; {@code ${param}} and {@code ${sql}} are filled per placeholder. */
    private static final String IF_TEMPLATE =
            "<if test=\"request.${param} != null and request.${param} != ''\">${sql}</if>";

    /**
     * Parses {@code sql} and resolves everything it references against {@code dbInfo}.
     *
     * @param sql    a single MySQL statement; parse failures from
     *               {@link SQLUtils#parseSingleMysqlStatement(String)} propagate unchanged
     * @param dbInfo table metadata, looked up by the table names Druid's schema visitor reports
     * @return the parse result; its query-parameter list is empty unless the statement is a SELECT,
     *         and its operation is {@code null} for statement kinds {@link #getOperation} does not map
     * @throws IllegalArgumentException if a referenced table or column is missing from {@code dbInfo}
     */
    public SQLParseResult parse(String sql, Map<String, TableEntity> dbInfo) {
        Objects.requireNonNull(dbInfo, "dbInfo");

        SQLStatement sqlStatement = SQLUtils.parseSingleMysqlStatement(sql);
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        sqlStatement.accept(visitor);

        List<Table> tableList = visitor.getTables().keySet().stream()
                .map(tableName -> toTable(requireTable(dbInfo, tableName.getName())))
                .collect(Collectors.toList());

        List<ConditionParam> conditionParams = visitor.getConditions().stream()
                .filter(SQLParser::isPlaceholderCondition)
                .map(condition -> toConditionParam(condition.getColumn(), dbInfo))
                .collect(Collectors.toCollection(ArrayList::new));

        Operation operation = getOperation(sqlStatement);

        // Selected columns only become result fields for SELECT statements.
        List<QueryParam> queryParams = new ArrayList<>();
        if (operation == Operation.query) {
            queryParams = visitor.getColumns().stream()
                    .filter(TableStat.Column::isSelect)
                    .map(column -> toQueryParam(column, dbInfo))
                    .collect(Collectors.toCollection(ArrayList::new));
        }

        SQLParseResult result = new SQLParseResult();
        result.setSource(sql);
        result.setOperation(operation);
        result.setConditionParamList(conditionParams);
        result.setQueryParamList(queryParams);
        result.setTableList(tableList);
        result.setMybatisSql(getMybatisSql(sql));
        return result;
    }

    /**
     * Keeps conditions whose first recorded value is {@code null}. The assumption is that these are
     * the ones bound to a {@code ?} placeholder, since Druid records no literal value for them
     * (TODO confirm for IN lists).
     */
    private static boolean isPlaceholderCondition(TableStat.Condition condition) {
        return !condition.getValues().isEmpty() && condition.getValues().get(0) == null;
    }

    /** Builds the {@link Table} descriptor for one referenced table. */
    private static Table toTable(TableEntity tableEntity) {
        Table table = new Table();
        table.setTableName(tableEntity.getTableName());
        table.setTableComment(tableEntity.getTableComment());
        table.setParameter(StrUtil.toCamelCase(table.getTableName()));
        return table;
    }

    /** Builds a condition parameter named after the camel-cased column name. */
    private static ConditionParam toConditionParam(TableStat.Column column,
                                                   Map<String, TableEntity> dbInfo) {
        ColumnEntity columnEntity = requireColumn(dbInfo, column);
        ConditionParam conditionParam = new ConditionParam();
        conditionParam.setParameter(StrUtil.toCamelCase(column.getName()));
        conditionParam.setDescription(columnEntity.getColumnComment());
        conditionParam.setTableName(columnEntity.getTableName());
        conditionParam.setTableComment(columnEntity.getTableComment());
        return conditionParam;
    }

    /** Builds a query (result) parameter named after the camel-cased column name. */
    private static QueryParam toQueryParam(TableStat.Column column, Map<String, TableEntity> dbInfo) {
        ColumnEntity columnEntity = requireColumn(dbInfo, column);
        QueryParam queryParam = new QueryParam();
        queryParam.setParameter(StrUtil.toCamelCase(column.getName()));
        queryParam.setDescription(columnEntity.getColumnComment());
        queryParam.setTableName(columnEntity.getTableName());
        queryParam.setTableComment(columnEntity.getTableComment());
        return queryParam;
    }

    /** Looks up table metadata, failing with a descriptive message instead of a later NPE. */
    private static TableEntity requireTable(Map<String, TableEntity> dbInfo, String tableName) {
        TableEntity tableEntity = dbInfo.get(tableName);
        if (tableEntity == null) {
            throw new IllegalArgumentException("Table not found in database metadata: " + tableName);
        }
        return tableEntity;
    }

    /** Looks up column metadata, failing with a descriptive message instead of a later NPE. */
    private static ColumnEntity requireColumn(Map<String, TableEntity> dbInfo, TableStat.Column column) {
        ColumnEntity columnEntity = requireTable(dbInfo, column.getTable()).getColumns().get(column.getName());
        if (columnEntity == null) {
            throw new IllegalArgumentException("Column not found in database metadata: "
                    + column.getTable() + "." + column.getName());
        }
        return columnEntity;
    }

    /**
     * Maps the statement type to an {@link Operation}.
     *
     * @return the matching operation, or {@code null} for any other statement kind
     */
    private Operation getOperation(SQLStatement sqlStatement) {
        if (sqlStatement instanceof SQLSelectStatement) {
            return Operation.query;
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            return Operation.update;
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            return Operation.remove;
        } else if (sqlStatement instanceof SQLInsertStatement) {
            return Operation.add;
        }
        return null;
    }

    /**
     * Rewrites each {@code ?} placeholder as {@code #{request.<param>}} and wraps its segment in a
     * MyBatis {@code <if>} that skips it when the request value is null or empty. {@code <param>} is
     * the camel-cased column compared against the placeholder, matching the condition parameter
     * names produced by {@link #parse}.
     *
     * <p>Segments whose column cannot be identified (e.g. {@code LIMIT ?}, {@code IN (?)}) are left
     * unchanged. NOTE(review): a segment starts where the previous one ended, so the first one also
     * contains the statement prefix (e.g. {@code SELECT ... WHERE}), which then sits inside the first
     * {@code <if>}. Confirm this is intended.
     */
    private String getMybatisSql(String source) {
        Matcher segmentMatcher = PLACEHOLDER_SEGMENT.matcher(source);
        StringBuffer out = new StringBuffer(source.length() * 2);
        while (segmentMatcher.find()) {
            String segment = segmentMatcher.group();
            Matcher columnMatcher = PLACEHOLDER_COLUMN.matcher(segment);
            String replacement = segment;
            if (columnMatcher.find()) {
                String param = StrUtil.toCamelCase(columnMatcher.group(1));
                // The reluctant segment pattern guarantees '?' occurs only as the last character.
                String boundSql = segment.substring(0, segment.length() - 1) + "#{request." + param + "}";
                replacement = IF_TEMPLATE.replace("${param}", param).replace("${sql}", boundSql);
            }
            // Positional rewrite: only this occurrence changes; quoteReplacement keeps '$' and '\' literal.
            segmentMatcher.appendReplacement(out, Matcher.quoteReplacement(replacement));
        }
        segmentMatcher.appendTail(out);
        return out.toString();
    }

}
